{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
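        "# Compile PyTorch Object Detection Models\n",
        "\n",
        "This is an introductory tutorial on deploying a PyTorch object detection model with the Relay VM.\n",
        "\n",
        "PyTorch must be installed first. TorchVision is also required, because it serves as the model zoo.\n",
        "\n",
        "A quick way to install both is via pip:\n",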
        "\n",
        "```bash\n",
        "pip install torch torchvision\n",
        "```\n",
        "\n",
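        "Alternatively, refer to the [official website](https://pytorch.org/get-started/locally/).\n",
        "\n",
        "PyTorch versions should be backward compatible, but each must be used with the matching TorchVision version.\n",
        "\n",
        "Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may be unstable."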
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [],
      "source": [
        "import env"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\n",
        "from tvm import relay\n",
        "from tvm.runtime.vm import VirtualMachine\n",
        "from tvm.contrib.download import download_testdata\n",
        "\n",
        "import numpy as np\n",
        "import cv2\n",
        "\n",
        "# PyTorch imports\n",
        "import torch\n",
        "import torchvision"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load a pretrained Mask R-CNN from torchvision and trace it"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [],
      "source": [
        "def do_trace(model, inp):\n",
        "    # Trace the model with an example input and return it in eval mode.\n",
        "    model_trace = torch.jit.trace(model, inp)\n",
        "    model_trace.eval()\n",
        "    return model_trace\n",
        "\n",
        "\n",
        "def dict_to_tuple(out_dict):\n",
        "    # Flatten a detection result dict into a tuple; masks are optional.\n",
        "    if \"masks\" in out_dict:\n",
        "        return out_dict[\"boxes\"], out_dict[\"scores\"], out_dict[\"labels\"], out_dict[\"masks\"]\n",
        "    return out_dict[\"boxes\"], out_dict[\"scores\"], out_dict[\"labels\"]\n",
        "\n",
        "\n",
        "class TraceWrapper(torch.nn.Module):\n",
        "    # Wraps a detection model so that forward() returns a tuple of tensors.\n",
        "    def __init__(self, model):\n",
        "        super().__init__()\n",
        "        self.model = model\n",
        "\n",
        "    def forward(self, inp):\n",
        "        out = self.model(inp)\n",
        "        # The model returns one dict per image; keep the first one.\n",
        "        return dict_to_tuple(out[0])"
      ]
    },
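    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The wrapper presumably exists because `torch.jit.trace` handles a tuple of tensors more readily than the list of dicts that torchvision detection models return. A minimal sketch of the flattening step follows; the single-letter strings are hypothetical stand-ins for real tensors and only show the resulting order.\n",
        "\n",
        "```python\n",
        "def dict_to_tuple(out_dict):\n",
        "    # Same logic as the notebook helper: fixed field order, masks optional.\n",
        "    if 'masks' in out_dict:\n",
        "        return out_dict['boxes'], out_dict['scores'], out_dict['labels'], out_dict['masks']\n",
        "    return out_dict['boxes'], out_dict['scores'], out_dict['labels']\n",
        "\n",
        "\n",
        "# Hypothetical stand-ins for tensors, used only to show the ordering.\n",
        "with_masks = {'labels': 'L', 'masks': 'M', 'scores': 'S', 'boxes': 'B'}\n",
        "without_masks = {'labels': 'L', 'scores': 'S', 'boxes': 'B'}\n",
        "\n",
        "flat_with_masks = dict_to_tuple(with_masks)\n",
        "flat_without_masks = dict_to_tuple(without_masks)\n",
        "print(flat_with_masks)\n",
        "print(flat_without_masks)\n",
        "```\n",
        "\n",
        "The fixed order is what lets the later cells read boxes from `tvm_res[0]` and scores from `tvm_res[1]`."
      ]
    },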
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [],
      "source": [
        "in_size = 300\n",
        "input_shape = (1, 3, in_size, in_size)\n",
        "\n",
        "model_func = torchvision.models.detection.maskrcnn_resnet50_fpn\n",
        "\n",
        "model = TraceWrapper(model_func(pretrained=True))\n",
        "model.eval()\n",
        "inp = torch.rand(input_shape)\n",
        "\n",
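        "# Run the model once eagerly as a sanity check, then trace it, all without tracking gradients.\n",
        "with torch.no_grad():\n",
        "    out = model(inp)\n",
        "    script_module = do_trace(model, inp)"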
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Download a test image and preprocess it"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "img_url = (\n",
        "    \"https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg\"\n",
        ")\n",
        "img_path = download_testdata(img_url, \"test_street_small.jpg\", module=\"data\")\n",
        "\n",
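        "# OpenCV loads images as BGR in HWC layout; convert to RGB, scale to [0, 1], and reorder to NCHW.\n",
        "img = cv2.imread(img_path).astype(\"float32\")\n",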
        "img = cv2.resize(img, (in_size, in_size))\n",
        "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
        "img = np.transpose(img / 255.0, [2, 0, 1])\n",
        "img = np.expand_dims(img, axis=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Import the graph into Relay"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "input_name = \"input0\"\n",
        "shape_list = [(input_name, input_shape)]\n",
        "mod, params = relay.frontend.from_pytorch(script_module, shape_list)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Compile with the Relay VM\n",
        "\n",
        "```{note}\n",
        "Currently only CPU targets are supported. For x86 targets, the torchvision R-CNN models contain large dense operators, so building TVM with Intel MKL and Intel OpenMP is strongly recommended for best performance.\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Add \"-libs=mkl\" to get the best performance on x86 targets.\n",
        "# For an x86 machine that supports AVX-512, the complete target is\n",
        "# \"llvm -mcpu=skylake-avx512 -libs=mkl\".\n",
        "target = \"llvm\"\n",
        "\n",
        "with tvm.transform.PassContext(opt_level=3,\n",
        "                               disabled_pass=[\"FoldScaleAxis\"]):\n",
        "    vm_exec = relay.vm.compile(mod, target=target, params=params)"
      ]
    },
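    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The comments in the cell above describe several variants of the target string. As a rough sketch, a small helper can assemble them. Note that `build_target` is a hypothetical name, not a TVM API, and it only concatenates the flags already mentioned above.\n",
        "\n",
        "```python\n",
        "def build_target(mcpu=None, use_mkl=False):\n",
        "    # Hypothetical helper, not part of TVM: joins the flags named in the comments.\n",
        "    parts = ['llvm']\n",
        "    if mcpu:\n",
        "        parts.append('-mcpu=' + mcpu)\n",
        "    if use_mkl:\n",
        "        parts.append('-libs=mkl')\n",
        "    return ' '.join(parts)\n",
        "\n",
        "\n",
        "default_target = build_target()\n",
        "avx512_target = build_target(mcpu='skylake-avx512', use_mkl=True)\n",
        "print(default_target)\n",
        "print(avx512_target)\n",
        "```\n",
        "\n",
        "The `-libs=mkl` flag assumes TVM was built with MKL enabled, as the note above recommends."
      ]
    },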
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Run inference with the Relay VM"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dev = tvm.cpu()\n",
        "vm = VirtualMachine(vm_exec, dev)\n",
        "vm.set_input(\"main\", **{input_name: img})\n",
        "tvm_res = vm.run()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Get boxes with scores greater than 0.9"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Get 9 valid boxes\n"
          ]
        }
      ],
      "source": [
        "score_threshold = 0.9\n",
        "boxes = tvm_res[0].numpy().tolist()\n",
        "valid_boxes = []\n",
        "for i, score in enumerate(tvm_res[1].numpy().tolist()):\n",
        "    if score > score_threshold:\n",
        "        valid_boxes.append(boxes[i])\n",
        "    else:\n",
        "        # Stop at the first low score; this assumes scores are sorted in descending order.\n",
        "        break\n",
        "\n",
        "print(\"Get {} valid boxes\".format(len(valid_boxes)))"
      ]
    },
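    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop above stops at the first score that does not exceed the threshold, which is only correct if the scores arrive sorted from highest to lowest. That appears to hold for torchvision detection outputs, but treat it as an assumption. A minimal sketch with made-up numbers compares the early-exit filter with an order-independent variant:\n",
        "\n",
        "```python\n",
        "def filter_sorted(boxes, scores, threshold):\n",
        "    # Early exit: assumes scores are sorted from highest to lowest.\n",
        "    kept = []\n",
        "    for box, score in zip(boxes, scores):\n",
        "        if score <= threshold:\n",
        "            break\n",
        "        kept.append(box)\n",
        "    return kept\n",
        "\n",
        "\n",
        "def filter_any_order(boxes, scores, threshold):\n",
        "    # No ordering assumption: checks every score.\n",
        "    return [box for box, score in zip(boxes, scores) if score > threshold]\n",
        "\n",
        "\n",
        "# Made-up boxes and scores, for illustration only.\n",
        "boxes = [[0, 0, 10, 10], [5, 5, 20, 20], [1, 1, 2, 2]]\n",
        "sorted_scores = [0.98, 0.95, 0.40]\n",
        "unsorted_scores = [0.98, 0.40, 0.95]\n",
        "\n",
        "print(filter_sorted(boxes, sorted_scores, 0.9))\n",
        "print(filter_any_order(boxes, unsorted_scores, 0.9))\n",
        "```\n",
        "\n",
        "With unsorted scores the early-exit version would drop the third box, so the order-independent variant is the safer choice when the ordering is not guaranteed."
      ]
    },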
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "interpreter": {
      "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
    },
    "kernelspec": {
      "display_name": "Python 3.8.10 64-bit",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
